{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e799766a-fa29-4982-bfaa-34d5ed8631c8",
   "metadata": {},
   "source": [
    "# Pose classification model optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7acc177a-a0c3-4bad-9c32-9f956149fb6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Physical GPUs, 1 Logical GPUs\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.python.compiler.tensorrt import trt_convert as trt\n",
    "from tensorflow import keras\n",
    "from keras.utils.data_utils import get_file\n",
    "from tensorflow.python.saved_model import tag_constants\n",
    "\n",
    "import json\n",
    "import numpy as np\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "\n",
    "from pose_classification_kit.datasets import bodyDataset, BODY18, dataAugmentation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0037c884-b968-475d-add4-f25d9e9d0558",
   "metadata": {},
   "source": [
    "## Dataset & Keras model import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1e483d21-94d1-4691-9f56-15c32e44ec76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train: (7040, 18, 2)\n",
      "y_train: (7040,)\n",
      "y_train_onehot: (7040, 20)\n",
      "x_test: (3000, 18, 2)\n",
      "y_test: (3000,)\n",
      "y_test_onehot: (3000, 20)\n",
      "labels: (20,)\n"
     ]
    }
   ],
   "source": [
    "# Load the BODY18 dataset once and convert every entry to a NumPy array.\n",
    "dataset = {\n",
    "    key: np.array(val)\n",
    "    for key, val in bodyDataset(testSplit = .3, bodyModel = BODY18).items()\n",
    "}\n",
    "for key, val in dataset.items():\n",
    "    print(key+':', val.shape)\n",
    "\n",
    "# Bind each entry explicitly (instead of injecting globals through exec) so that\n",
    "# every name used by the following cells is visibly defined here.\n",
    "x_train = dataset['x_train']\n",
    "y_train = dataset['y_train']\n",
    "y_train_onehot = dataset['y_train_onehot']\n",
    "x_test = dataset['x_test']\n",
    "y_test = dataset['y_test']\n",
    "y_test_onehot = dataset['y_test_onehot']\n",
    "labels = dataset['labels']\n",
    "\n",
    "# Indices, within BODY18.mapping, of the knee, ankle and hip keypoints.\n",
    "lower_body_keypoints = np.where(np.isin(BODY18.mapping,[\n",
    "    \"left_knee\", \"right_knee\", \"left_ankle\", \"right_ankle\", \"left_hip\", \"right_hip\"\n",
    "    ]))[0]\n",
    "\n",
    "# Augmented copy of the test set, later reported as the 'partial input' dataset.\n",
    "# NOTE(review): exact semantics of augmentation_ratio are defined by\n",
    "# pose_classification_kit -- confirm against its documentation.\n",
    "x_test_aug, y_test_aug = dataAugmentation(\n",
    "    x_test, y_test_onehot,\n",
    "    augmentation_ratio = 1.,\n",
    "    remove_specific_keypoints = lower_body_keypoints,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "724b60cb-2441-4625-827f-64724a38f6fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained Keras classifier from './Robust_BODY18.h5'.\n",
    "model_keras = keras.models.load_model('./Robust_BODY18.h5')\n",
    "\n",
    "# The companion JSON file carries the label list the model was trained with;\n",
    "# it has to agree with the labels of the dataset loaded above.\n",
    "with open('Robust_BODY18_Info.json') as info_file:\n",
    "    model_info = json.load(info_file)\n",
    "model_labels = np.array(model_info['labels'])\n",
    "assert np.array_equiv(model_labels, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ac28fd0e-d8b7-450c-b04d-27e22a21ecb8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 315/3000 [01:15<10:41,  4.18it/s] \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-c362a8b1ab19>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnb_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdur_tot\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnb_samples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mduration\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbenchmark_keras\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_keras\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test_onehot\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0maccuracy_aug\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mduration_aug\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbenchmark_keras\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_keras\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_test_aug\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test_aug\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-4-c362a8b1ab19>\u001b[0m in \u001b[0;36mbenchmark_keras\u001b[0;34m(keras_model, ds_x, ds_y)\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mds_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mds_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnb_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m         \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m         \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkeras_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m         \u001b[0mdur_tot\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m   1704\u001b[0m           \u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1705\u001b[0m           \u001b[0mmodel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1706\u001b[0;31m           steps_per_execution=self._steps_per_execution)\n\u001b[0m\u001b[1;32m   1707\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1708\u001b[0m       \u001b[0;31m# Container that configures and calls `tf.keras.Callback`s.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py\u001b[0m in \u001b[0;36mget_data_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1362\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"model\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"_cluster_coordinator\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1363\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0m_ClusterCoordinatorDataHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1364\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mDataHandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1365\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1366\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, x, y, sample_weight, batch_size, steps_per_epoch, initial_epoch, epochs, shuffle, class_weight, max_queue_size, workers, use_multiprocessing, model, steps_per_execution, distribute)\u001b[0m\n\u001b[1;32m   1164\u001b[0m         \u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1165\u001b[0m         \u001b[0mdistribution_strategy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mds_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_strategy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1166\u001b[0;31m         model=model)\n\u001b[0m\u001b[1;32m   1167\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1168\u001b[0m     \u001b[0mstrategy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mds_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_strategy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/data_adapter.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, x, y, sample_weights, sample_weight_modes, batch_size, epochs, steps, shuffle, **kwargs)\u001b[0m\n\u001b[1;32m    335\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mflat_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    336\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 337\u001b[0;31m     \u001b[0mindices_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindices_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflat_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mslice_batch_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    339\u001b[0m     \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mslice_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindices_dataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36mflat_map\u001b[0;34m(self, map_func)\u001b[0m\n\u001b[1;32m   1955\u001b[0m       \u001b[0mDataset\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mA\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mDataset\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1956\u001b[0m     \"\"\"\n\u001b[0;32m-> 1957\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mFlatMapDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1958\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1959\u001b[0m   def interleave(self,\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, input_dataset, map_func)\u001b[0m\n\u001b[1;32m   4563\u001b[0m     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_input_dataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4564\u001b[0m     self._map_func = StructuredFunctionWrapper(\n\u001b[0;32m-> 4565\u001b[0;31m         map_func, self._transformation_name(), dataset=input_dataset)\n\u001b[0m\u001b[1;32m   4566\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_func\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_structure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDatasetSpec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4567\u001b[0m       raise TypeError(\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)\u001b[0m\n\u001b[1;32m   3706\u001b[0m               \u001b[0;34m\"To force eager execution of tf.data functions, please use \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3707\u001b[0m               \"`tf.data.experimental.enable.debug_mode()`.\")\n\u001b[0;32m-> 3708\u001b[0;31m         \u001b[0mfn_factory\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_tf_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdefun_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   3709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3710\u001b[0m     \u001b[0mresource_tracker\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtracking\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mResourceTracker\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py\u001b[0m in \u001b[0;36mtrace_tf_function\u001b[0;34m(defun_kwargs)\u001b[0m\n\u001b[1;32m   3683\u001b[0m               self._input_structure),\n\u001b[1;32m   3684\u001b[0m           \u001b[0mautograph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3685\u001b[0;31m           attributes=defun_kwargs)\n\u001b[0m\u001b[1;32m   3686\u001b[0m       \u001b[0;32mdef\u001b[0m \u001b[0mwrapped_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=missing-docstring\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3687\u001b[0m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwrapper_helper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py\u001b[0m in \u001b[0;36mdecorated\u001b[0;34m(function)\u001b[0m\n\u001b[1;32m   3884\u001b[0m             \u001b[0mjit_compile\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjit_compile\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3885\u001b[0m             \u001b[0mexperimental_relax_shapes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexperimental_relax_shapes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3886\u001b[0;31m             experimental_follow_type_hints=experimental_follow_type_hints))\n\u001b[0m\u001b[1;32m   3887\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   3888\u001b[0m   \u001b[0;31m# This code path is for the `foo = tfe.defun(foo, ...)` use case\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/tf_decorator.py\u001b[0m in \u001b[0;36mmake_decorator\u001b[0;34m(target, decorator_func, decorator_name, decorator_doc, decorator_argspec)\u001b[0m\n\u001b[1;32m    105\u001b[0m   \u001b[0;31m# Keeping a second handle to `target` allows callers to detect whether the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m   \u001b[0;31m# decorator was modified using `rewrap`.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 107\u001b[0;31m   \u001b[0mdecorator_func\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__original_wrapped__\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    108\u001b[0m   \u001b[0;32mreturn\u001b[0m \u001b[0mdecorator_func\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def benchmark_keras(keras_model, ds_x, ds_y):\n",
    "    \"\"\"Measure top-1 accuracy and mean per-sample latency of a Keras model.\n",
    "\n",
    "    Samples are fed one at a time (batch of size 1).\n",
    "\n",
    "    Args:\n",
    "        keras_model: Keras model, invoked as `keras_model(batch, training=False)`.\n",
    "        ds_x: Array of input samples, iterated along its first axis.\n",
    "        ds_y: One score vector per sample; its argmax is the expected class.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(accuracy, mean_duration)`: the fraction of samples whose predicted\n",
    "        class matches the expected one, and the mean duration in seconds of one\n",
    "        model call.\n",
    "    \"\"\"\n",
    "    nb_samples = ds_x.shape[0]\n",
    "    acc = 0\n",
    "    dur_tot = 0.\n",
    "    for x, y in tqdm(zip(ds_x, ds_y), total=nb_samples):\n",
    "        # Input preparation is kept outside the timed region, as in benchmark_trt.\n",
    "        batch = tf.constant(np.expand_dims(x, axis=0), dtype=tf.float32)\n",
    "        start_time = time.perf_counter()\n",
    "        # Call the model directly: `predict()` is designed for large batches and\n",
    "        # sets up a new data pipeline on every call, which dominates the runtime\n",
    "        # when it is used on single samples inside a loop.\n",
    "        output = keras_model(batch, training=False)\n",
    "        dur_tot += time.perf_counter() - start_time\n",
    "        if np.argmax(np.asarray(output)[0]) == np.argmax(y):\n",
    "            acc += 1\n",
    "    return acc / nb_samples, dur_tot / nb_samples\n",
    "\n",
    "accuracy, duration = benchmark_keras(model_keras, x_test, y_test_onehot)\n",
    "accuracy_aug, duration_aug = benchmark_keras(model_keras, x_test_aug, y_test_aug)\n",
    "\n",
    "# Durations are in seconds. They are reported in milliseconds with two decimals:\n",
    "# truncating with int() would hide sub-millisecond differences between models.\n",
    "print('--- Raw testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy))\n",
    "print('Inference time: {:.2f}ms'.format(duration*1000))\n",
    "print('--- Partial input testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy_aug))\n",
    "print('Inference time: {:.2f}ms \\n'.format(duration_aug*1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb084a2f-e25d-44c7-bd3f-1dc23056b410",
   "metadata": {},
   "source": [
    "Save the Keras model in the TensorFlow SavedModel format, which is the input expected by the TF-TRT converter used below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b4ed2c12-ee5e-40c0-aa64-95b088b4ed3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ./Robust_BODY18_In/assets\n"
     ]
    }
   ],
   "source": [
    "# Export the Keras model to the './Robust_BODY18_In' directory, which the\n",
    "# TF-TRT converters below read as their input.\n",
    "model_keras.save(filepath='./Robust_BODY18_In')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44d4a52e-bd32-4d42-a244-7787694e7c97",
   "metadata": {},
   "source": [
    "## TF-TRT Comparisons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "816abd17-d8ff-434c-bd93-029ad89c8bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def benchmark_trt(inference_function, ds_x, ds_y, output_key='dense_20'):\n",
    "    \"\"\"Measure top-1 accuracy and mean per-sample latency of a serving signature.\n",
    "\n",
    "    Samples are fed one at a time as float32 tensors with a leading batch axis of 1.\n",
    "\n",
    "    Args:\n",
    "        inference_function: Callable taking one tensor and returning a mapping of\n",
    "            output names to tensors (e.g. a SavedModel signature).\n",
    "        ds_x: Array of input samples, iterated along its first axis.\n",
    "        ds_y: One score vector per sample; its argmax is the expected class.\n",
    "        output_key: Name of the entry of the returned mapping that holds the class\n",
    "            scores. Defaults to 'dense_20', the name used by the model of this\n",
    "            notebook.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(accuracy, mean_duration)`: the fraction of samples whose predicted\n",
    "        class matches the expected one, and the mean duration in seconds of one\n",
    "        call to `inference_function`.\n",
    "    \"\"\"\n",
    "    nb_samples = ds_x.shape[0]\n",
    "    acc = 0\n",
    "    dur_tot = 0.\n",
    "    for x, y in tqdm(zip(ds_x, ds_y), total=nb_samples):\n",
    "        # Tensor creation is kept outside the timed region: only the call is measured.\n",
    "        x = tf.constant(np.expand_dims(x, axis=0), dtype=tf.float32)\n",
    "        start_time = time.perf_counter()\n",
    "        output = inference_function(x)\n",
    "        dur_tot += time.perf_counter() - start_time\n",
    "        if np.argmax(output[output_key][0]) == np.argmax(y):\n",
    "            acc += 1\n",
    "    return acc / nb_samples, dur_tot / nb_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03f9bc24-9ed3-4a86-859e-eb40a30ed72e",
   "metadata": {},
   "source": [
    "### FP32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6cfe485e-6376-4164-a9de-5253603ee7e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Linked TensorRT version: (7, 1, 3)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 1, 3)\n",
      "INFO:tensorflow:Assets written to: ./Robust_BODY18_FP32/assets\n"
     ]
    }
   ],
   "source": [
    "# Conversion settings: FP32 precision, 2**25 bytes of workspace, up to 64 cached engines.\n",
    "params_32 = tf.experimental.tensorrt.ConversionParams(\n",
    "    precision_mode='FP32',\n",
    "    max_workspace_size_bytes=2 ** 25,\n",
    "    maximum_cached_engines=64,\n",
    ")\n",
    "\n",
    "converter = trt.TrtGraphConverterV2(\n",
    "    input_saved_model_dir='./Robust_BODY18_In',\n",
    "    conversion_params=params_32,\n",
    ")\n",
    "\n",
    "# Partition the graph and optimize the TensorRT-compatible segments.\n",
    "converter.convert()\n",
    "\n",
    "\n",
    "def input_fn():\n",
    "    \"\"\"Yield one random float32 input of shape (1, 18, 2) for the engine build.\"\"\"\n",
    "    random_input = np.random.normal(size=(1, 18, 2))\n",
    "    yield [random_input.astype(np.float32)]\n",
    "\n",
    "\n",
    "# Optional step: build the TensorRT engines now rather than on the first inference.\n",
    "# The resulting engines are specific to the GPU they are built on.\n",
    "converter.build(input_fn=input_fn)\n",
    "\n",
    "# Write the converted model to disk.\n",
    "converter.save('./Robust_BODY18_FP32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c21fa066-0b83-4fd1-9d22-bd484b4f6a9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the converted FP32 model and keep a handle on its default serving signature.\n",
    "model_FP32 = tf.saved_model.load(\n",
    "    './Robust_BODY18_FP32',\n",
    "    tags=[tag_constants.SERVING],\n",
    ")\n",
    "infer_FP32 = model_FP32.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "76c91dbd-b0f8-4c8c-beb1-f1b1851a2756",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3000/3000 [00:23<00:00, 126.73it/s]\n",
      "100%|██████████| 3000/3000 [00:23<00:00, 127.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Raw testing dataset ---\n",
      "Accuracy: 98.37%\n",
      "Inference time: 6ms\n",
      "--- Partial input testing dataset ---\n",
      "Accuracy: 95.37%\n",
      "Inference time: 5ms \n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "accuracy, duration = benchmark_trt(infer_FP32, x_test, y_test_onehot)\n",
    "accuracy_aug, duration_aug = benchmark_trt(infer_FP32, x_test_aug, y_test_aug)\n",
    "\n",
    "# Durations are in seconds. They are reported in milliseconds with two decimals:\n",
    "# truncating with int() would hide sub-millisecond differences between models.\n",
    "print('--- Raw testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy))\n",
    "print('Inference time: {:.2f}ms'.format(duration*1000))\n",
    "print('--- Partial input testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy_aug))\n",
    "print('Inference time: {:.2f}ms \\n'.format(duration_aug*1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f97a66a3-a57f-4791-9160-3001e13952ad",
   "metadata": {},
   "source": [
    "### FP16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b793b931-1d81-4261-bd4a-1fa3e40a5709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conversion settings: FP16 precision, 2**25 bytes of workspace, up to 64 cached engines.\n",
    "params_16 = tf.experimental.tensorrt.ConversionParams(\n",
    "    precision_mode='FP16',\n",
    "    max_workspace_size_bytes=2 ** 25,\n",
    "    maximum_cached_engines=64,\n",
    ")\n",
    "\n",
    "converter = trt.TrtGraphConverterV2(\n",
    "    input_saved_model_dir='./Robust_BODY18_In',\n",
    "    conversion_params=params_16,\n",
    ")\n",
    "\n",
    "# Partition the graph and optimize the TensorRT-compatible segments.\n",
    "converter.convert()\n",
    "\n",
    "\n",
    "def input_fn():\n",
    "    \"\"\"Yield one random float32 input of shape (1, 18, 2) for the engine build.\"\"\"\n",
    "    random_input = np.random.normal(size=(1, 18, 2))\n",
    "    yield [random_input.astype(np.float32)]\n",
    "\n",
    "\n",
    "# Optional step: build the TensorRT engines now rather than on the first inference.\n",
    "# The resulting engines are specific to the GPU they are built on.\n",
    "converter.build(input_fn=input_fn)\n",
    "\n",
    "# Write the converted model to disk.\n",
    "converter.save('./Robust_BODY18_FP16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c9ef1c5-666b-43bf-a22b-63667b13a3b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the converted FP16 model and keep a handle on its default serving signature.\n",
    "model_FP16 = tf.saved_model.load(\n",
    "    './Robust_BODY18_FP16',\n",
    "    tags=[tag_constants.SERVING],\n",
    ")\n",
    "infer_FP16 = model_FP16.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf072704-8e85-4c28-a804-8ff02c76208a",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy, duration = benchmark_trt(infer_FP16, x_test, y_test_onehot)\n",
    "accuracy_aug, duration_aug = benchmark_trt(infer_FP16, x_test_aug, y_test_aug)\n",
    "\n",
    "# Durations are in seconds. They are reported in milliseconds with two decimals:\n",
    "# truncating with int() would hide sub-millisecond differences between models.\n",
    "print('--- Raw testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy))\n",
    "print('Inference time: {:.2f}ms'.format(duration*1000))\n",
    "print('--- Partial input testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy_aug))\n",
    "print('Inference time: {:.2f}ms \\n'.format(duration_aug*1000))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "496f958c-a5fc-4ad4-aeb0-ddb46e62247c",
   "metadata": {},
   "source": [
    "### INT8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d60d9c8c-c212-4326-a693-4913ac4bce6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_8 = tf.experimental.tensorrt.ConversionParams(\n",
    "    precision_mode='INT8',\n",
    "    max_workspace_size_bytes=(1<<25),\n",
    "    maximum_cached_engines=64\n",
    ")\n",
    "\n",
    "converter = trt.TrtGraphConverterV2(\n",
    "    input_saved_model_dir='./Robust_BODY18_In',\n",
    "    conversion_params=params_8)\n",
    "\n",
    "\n",
    "def calibration_input_fn(nb_samples=256):\n",
    "    \"\"\"Yield real training samples, one per batch of shape (1, 18, 2), for INT8 calibration.\n",
    "\n",
    "    Calibration derives the quantization ranges from the data it is run on, so it\n",
    "    is fed samples from `x_train` instead of random noise.\n",
    "    NOTE(review): 256 samples is an arbitrary starting point -- tune as needed.\n",
    "    \"\"\"\n",
    "    for sample in x_train[:nb_samples]:\n",
    "        yield [np.expand_dims(sample, axis=0).astype(np.float32)]\n",
    "\n",
    "\n",
    "def input_fn():\n",
    "    \"\"\"Yield one random float32 input of shape (1, 18, 2) for the engine build.\"\"\"\n",
    "    yield [np.random.normal(size=(1, 18, 2)).astype(np.float32)]\n",
    "\n",
    "\n",
    "# Partition and optimize the TensorRT compatible segments, calibrating on real data.\n",
    "converter.convert(calibration_input_fn=calibration_input_fn)\n",
    "\n",
    "# Optionally, build TensorRT engines before deployment to save time at runtime.\n",
    "# Note that the built engines are GPU specific.\n",
    "converter.build(input_fn=input_fn)\n",
    "\n",
    "# Save the model to the disk\n",
    "converter.save('./Robust_BODY18_INT8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e31f768-3245-4121-9fc7-3691e0f0c97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the converted INT8 model and keep a handle on its default serving signature.\n",
    "model_INT8 = tf.saved_model.load(\n",
    "    './Robust_BODY18_INT8',\n",
    "    tags=[tag_constants.SERVING],\n",
    ")\n",
    "infer_INT8 = model_INT8.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88e1b78c-0fca-4123-a641-9389708f190e",
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy, duration = benchmark_trt(infer_INT8, x_test, y_test_onehot)\n",
    "accuracy_aug, duration_aug = benchmark_trt(infer_INT8, x_test_aug, y_test_aug)\n",
    "\n",
    "# Durations are in seconds. They are reported in milliseconds with two decimals:\n",
    "# truncating with int() would hide sub-millisecond differences between models.\n",
    "print('--- Raw testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy))\n",
    "print('Inference time: {:.2f}ms'.format(duration*1000))\n",
    "print('--- Partial input testing dataset ---')\n",
    "print('Accuracy: {:.2%}'.format(accuracy_aug))\n",
    "print('Inference time: {:.2f}ms \\n'.format(duration_aug*1000))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0e2fd1c-a464-4c56-a3f5-1b4e1a39899c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
